# Purpose: compare normalization methods for a particular dataset

source(file.path("scripts", "CZI_functions.R"))

# Install the packages this script checks for, if they are not already
# available. system.file() returns "" for a package that is not installed;
# this avoids scanning the whole installed.packages() matrix, where `%in%`
# would also match any other cell (e.g. a dependency field) equal to the name.
for (pkg in c("Rtsne", "NMI", "caret")) {
   if (!nzchar(system.file(package = pkg))) {
      install.packages(pkg)
   }
}
# List the normalized expression tables (file names only, without directory)
normalized.files <- list.files("arnon_data/normalized_arnon")

# Load every table as tab-separated values, in the order listed above
normalized.data <- lapply(
  file.path("arnon_data/normalized_arnon", normalized.files),
  readr::read_tsv
)
## Parsed with column specification:
## cols(
##   .default = col_double(),
##   gene = col_character()
## )
## See spec(...) for full column specifications.
## Parsed with column specification:
## cols(
##   .default = col_double(),
##   Genes = col_character()
## )
## See spec(...) for full column specifications.
## Parsed with column specification:
## cols(
##   .default = col_double(),
##   Genes = col_character()
## )
## See spec(...) for full column specifications.
## Parsed with column specification:
## cols(
##   .default = col_double(),
##   Genes = col_character()
## )
## See spec(...) for full column specifications.
# Pull the first column (the gene identifiers) out of every table
genes <- lapply(normalized.data, function(tbl) tbl[, 1])

# Label each table with its file name, minus the ".tab" extension and the
# "arnon_" prefix
names(normalized.data) <- gsub(pattern = "\\.tab|arnon_", replacement = "",
                               x = normalized.files)
# Load the sample metadata
meta <- readr::read_csv(file.path("arnon_data", "meta_data.csv"))
## Parsed with column specification:
## cols(
##   .default = col_character(),
##   channel_count = col_double(),
##   taxid_ch1 = col_double(),
##   data_row_count = col_double()
## )
## See spec(...) for full column specifications.
# Sample names are the column names of the first table, minus the leading
# gene column
samples <- names(normalized.data[[1]])[-1]

# Reorder the metadata rows so they line up with the samples, matching on
# the metadata `title` column
meta <- meta[match(x = samples, table = meta$title), ]

# Extract cell type and cohort (plate batch) labels as stand-alone factors
cell.types <- as.factor(meta$cell.type.ch1)
plate.batch <- as.factor(meta$cohort.ch1)
# Run t-SNE on each normalized dataset and keep only the embedding matrix
tsne <- lapply(normalized.data, function(dat) {
  # Drop the gene column and transpose so that samples become rows
  tsne.res <- Rtsne::Rtsne(t(dat[, -1]), check_duplicates = FALSE)

  # Y is the embedding returned by Rtsne
  embedding <- tsne.res$Y

  # Label each row with its sample name. `names<-` on a matrix would attach
  # per-element names (padded with NA) instead of row names.
  rownames(embedding) <- samples
  embedding
})
tsne.plot <- function(dat, var, name = "name") {
  # Scatter-plot an embedding, colouring the points by a grouping variable,
  # and add a legend in the bottom-left corner.
  #
  # Args:
  #  dat: coordinates to plot, passed straight to plot()
  #  var: grouping variable with one entry per plotted point; coerced to a
  #       factor so that character vectors work too
  #  name: plot title
  # Returns:
  #  the value returned by legend()
  var <- as.factor(var)
  lvls <- levels(var)

  # Pick one random colour per level. Sampling without replacement keeps two
  # levels from sharing a colour; replacement is only allowed when there are
  # more levels than distinct colours.
  all.colours <- colors(distinct = TRUE)
  colz <- sample(all.colours, length(lvls),
                 replace = length(lvls) > length(all.colours))

  # Indexing by the factor maps every point to the colour of its level
  plot(dat, pch = 21, bg = colz[var], main = name)
  legend(x = "bottomleft", legend = lvls, fill = colz, cex = 0.6)
}

# Plot every embedding with cell type and plate batch labels
lapply(setNames(nm = names(tsne)), function(set.name) {
      # Iterate over the names so the normalization method is known directly
      dataset <- tsne[[set.name]]

      # tsne.plot() draws on whichever device is active when it is called, so
      # the png device has to be open *before* plotting. on.exit() closes the
      # device even if plotting fails.
      save.png <- function(var, suffix) {
        png(paste0("tsne_", set.name, suffix))
        on.exit(dev.off(), add = TRUE)
        tsne.plot(dataset, var = var, name = set.name)
      }

      # Save the plots to pngs
      save.png(cell.types, "cell_type.png")
      save.png(plate.batch, "plate_batch.png")

      # Draw the plots again on the current device so they show up in the Rmd.
      # Colours are picked at random on every tsne.plot() call, so they may
      # differ from the saved pngs.
      tsne.plot(dataset, var = cell.types, name = set.name)
      tsne.plot(dataset, var = plate.batch, name = set.name)
})

## $counts
## $counts$rect
## $counts$rect$w
## [1] 8.802387
## 
## $counts$rect$h
## [1] 14.04243
## 
## $counts$rect$left
## [1] -48.05311
## 
## $counts$rect$top
## [1] -50.55011
## 
## 
## $counts$text
## $counts$text$x
## [1] -44.19424 -44.19424
## 
## $counts$text$y
## [1] -55.23093 -59.91174
## 
## 
## 
## $log2
## $log2$rect
## $log2$rect$w
## [1] 10.0968
## 
## $log2$rect$h
## [1] 11.62271
## 
## $log2$rect$left
## [1] -54.14677
## 
## $log2$rect$top
## [1] -42.40292
## 
## 
## $log2$text
## $log2$text$x
## [1] -49.72044 -49.72044
## 
## $log2$text$y
## [1] -46.27715 -50.15139
## 
## 
## 
## $TMM
## $TMM$rect
## $TMM$rect$w
## [1] 10.53523
## 
## $TMM$rect$h
## [1] 12.17734
## 
## $TMM$rect$left
## [1] -48.38835
## 
## $TMM$rect$top
## [1] -45.10525
## 
## 
## $TMM$text
## $TMM$text$x
## [1] -43.76982 -43.76982
## 
## $TMM$text$y
## [1] -49.16436 -53.22347
## 
## 
## 
## $voom
## $voom$rect
## $voom$rect$w
## [1] 11.38769
## 
## $voom$rect$h
## [1] 12.27284
## 
## $voom$rect$left
## [1] -52.13407
## 
## $voom$rect$top
## [1] -37.52613
## 
## 
## $voom$text
## $voom$text$x
## [1] -47.14183 -47.14183
## 
## $voom$text$y
## [1] -41.61708 -45.70803
kmeans_eval <- function(feature, metadata = metadata, iter = 10, seed = 1234) {
  # Repeatedly run k-means on projected features and score every run against
  # known labels with normalized mutual information (NMI) and the adjusted
  # Rand index (ARI).
  #
  # Args:
  #  feature: matrix or data.frame of projected features; rows are samples
  #           (cells), columns are the dimensions of the projection
  #  metadata: vector of known labels (e.g. cell type), one per row of
  #            feature. Must be supplied by the caller.
  #  iter: number of k-means runs
  #  seed: RNG seed applied before the runs so the results are reproducible;
  #        NULL leaves the RNG untouched. The caller's RNG state is restored
  #        when the function exits.
  # Returns:
  #  data.frame with one row per run and the columns `ari` and `nmi`
  if (missing(metadata)) {
    # The self-referential default cannot be evaluated; fail with a clear
    # message instead of "promise already under evaluation"
    stop("`metadata` must be supplied", call. = FALSE)
  }
  stopifnot(length(metadata) == nrow(feature))

  if (!is.null(seed)) {
    # Seed locally: remember the global RNG state and put it back on exit
    had.seed <- exists(".Random.seed", envir = globalenv(), inherits = FALSE)
    old.seed <- if (had.seed) get(".Random.seed", envir = globalenv()) else NULL
    on.exit({
      if (had.seed) {
        assign(".Random.seed", old.seed, envir = globalenv())
      } else {
        rm(list = ".Random.seed", envir = globalenv())
      }
    }, add = TRUE)
    set.seed(seed)
  }

  # Convert the labels into numbers
  metadata_num <- as.numeric(factor(metadata))
  sample_id <- seq_len(nrow(feature))

  # Set k equal to the number of distinct labels in the dataset
  k <- length(unique(metadata))

  # True clusters; identical for every run, so built once
  orignal_data <- data.frame(sample_id, metadata_num)

  # One score per run
  nmi_score_all <- numeric(iter)
  ari_score_all <- numeric(iter)

  for (i in seq_len(iter)) {
    # Perform k-means clustering
    km <- kmeans(feature, k)

    # Predicted clusters
    cl_data <- data.frame(sample_id, km$cluster)

    # Calculate NMI and ARI score
    nmi_score_all[i] <- NMI::NMI(orignal_data, cl_data)$value
    ari_score_all[i] <- mclust::adjustedRandIndex(km$cluster, metadata_num)
  }

  # Compile all results into a data.frame
  data.frame(ari = ari_score_all, nmi = nmi_score_all)
}
knn_eval <- function(feature, metadata = metadata, k = 10) {
  # Evaluate how well a k-nearest-neighbour classifier recovers known labels
  # from projected features, using k-fold cross validation.
  #
  # Args:
  #  feature: matrix or data.frame of projected features; rows are samples
  #           (cells), columns are the dimensions of the projection
  #  metadata: vector of known labels (e.g. cell type), one per row of
  #            feature. Must be supplied by the caller.
  #  k: number of cross validation folds (not the number of neighbours, which
  #     caret tunes itself)
  # Returns:
  #  data.frame with a single column `knn` holding the accuracy values
  #  collected over the folds
  if (missing(metadata)) {
    # The self-referential default cannot be evaluated; fail with a clear
    # message instead of "promise already under evaluation"
    stop("`metadata` must be supplied", call. = FALSE)
  }

  # Combine the features and labels into one data.frame
  feature <- data.frame("tsne" = feature, "metadata" = metadata)

  # Split observations into k groups
  cv <- cvTools::cvFolds(nrow(feature), K = k, R = 1)

  # One performance entry per fold
  perf.eval <- vector("list", k)

  for (i in seq_len(k)) {
    # Fold i is held out for testing, everything else trains the model.
    # A logical mask avoids the `-which()` pitfall of selecting nothing when
    # there is no match.
    in.test <- cv$which == i
    train <- feature[cv$subsets[!in.test], ]
    test <- feature[cv$subsets[in.test], ]

    # Perform KNN model fitting; caret tunes over 10 candidate settings with
    # an inner 3-fold cross validation
    knn.fit <- caret::train(metadata ~ ., data = train, method = "knn",
                            trControl = caret::trainControl(method = "cv",
                                                            number = 3),
                            preProcess = c("center", "scale"),
                            tuneLength = 10)

    # Evaluate the model on the held-out fold.
    # NOTE(review): cal_performance is not defined in this file (presumably
    # it comes from scripts/CZI_functions.R); its result is expected to carry
    # an `accuracy` entry - confirm.
    knn.pred <- predict(knn.fit, newdata = subset(test, select = -c(metadata)))
    perf.eval[[i]] <- round(cal_performance(knn.pred, test$metadata, 3), 2)
  }

  # Stack the per-fold results and keep the accuracy
  perf.eval <- dplyr::bind_rows(perf.eval)
  data.frame("knn" = perf.eval$accuracy)
}
# Score every t-SNE embedding against the cell type labels
cell.type.results <- lapply(tsne, function(embedding) {
  # k-NN accuracy first, then the k-means ARI/NMI, combined column-wise
  data.frame(knn_eval(embedding, metadata = cell.types),
             kmeans_eval(embedding, metadata = cell.types))
})
## Loading required package: lattice
## Loading required package: ggplot2
## Warning: did not converge in 10 iterations
# Score every t-SNE embedding against the plate batch labels
batch.results <- lapply(tsne, function(embedding) {
  # k-NN accuracy first, then the k-means ARI/NMI, combined column-wise
  data.frame(knn_eval(embedding, metadata = plate.batch),
             kmeans_eval(embedding, metadata = plate.batch))
})
plot.results <- function(results.list, name = "results") {
  # Draw boxplots of the cluster statistics per normalization method, save
  # them to "<name>cluster_results.png" and return the plot.
  #
  # Args:
  #  results.list: named list of data.frames, one per normalization method;
  #                every column holds one statistic, every row one iteration
  #  name: name to use for the png to be saved and the plot title
  # Returns:
  #  the ggplot object (it is printed when the call is made at top level)
  stopifnot(length(results.list) > 0, !is.null(names(results.list)))

  # Number of values contributed by each data.frame
  n.values <- vapply(results.list, function(df) nrow(df) * ncol(df),
                     numeric(1))

  # Transform the list into a long data.frame. unlist() walks the list one
  # data.frame at a time and, within it, one column at a time; the label
  # vectors below follow that same order. Building them from the structure
  # (rather than parsing the unlist() names) works for any number of
  # iterations or statistics.
  ggplot.df <- data.frame(
    "method" = rep(names(results.list), times = n.values),
    "test" = unlist(lapply(results.list, function(df) {
      rep(names(df), each = nrow(df))
    }), use.names = FALSE),
    "iter" = unlist(lapply(results.list, function(df) {
      rep(seq_len(nrow(df)), times = ncol(df))
    }), use.names = FALSE),
    "values" = unlist(results.list)
  )

  # Make the plot; ggplot2 is referenced explicitly so the function does not
  # depend on the package having been attached elsewhere
  results.plot <- ggplot2::ggplot(data = ggplot.df,
                                  ggplot2::aes(x = method, y = values,
                                               fill = test)) +
    ggplot2::geom_boxplot(position = ggplot2::position_dodge()) +
    ggplot2::xlab("Normalization method") +
    ggplot2::ggtitle(name) +
    ggplot2::facet_wrap(~test)

  # Save this plot (not whatever plot happens to be the most recent) to png
  ggplot2::ggsave(paste0(name, "cluster_results.png"), plot = results.plot,
                  width = 10)

  # Return the plot so that it gets printed
  results.plot
}
# Boxplots of the clustering statistics against the cell type labels
plot.results(results.list = cell.type.results, name = "Cell_type")
## Saving 10 x 5 in image

# Boxplots of the clustering statistics against the plate batch labels
plot.results(results.list = batch.results, name = "Plate_batch")
## Saving 10 x 5 in image